#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>

#include <acl/acl.h>

#include "dataset.h"
#include "error_check.h"
using std::printf;

// Location of the offline (.om) LeNet model file that main() hands to
// aclmdlLoadFromFile().
constexpr const char* modelPath = "/home/HwHiAiUser/mnist/model/onnx_lenet.om";

int main() {
  int32_t majorVersion, minorVersion, patchVersion;
  CHECK(aclrtGetVersion(&majorVersion, &minorVersion, &patchVersion));
  printf("ACL version: %d.%d.%d\n", majorVersion, minorVersion, patchVersion);
  const char* socName = aclrtGetSocName();
  CHECK_NULLPTR(socName);
  printf("SoC Name: %s\n", socName);

  CHECK(aclInit(nullptr));

  uint32_t deviceCount = 0;
  const int32_t deviceId = 0;

  CHECK(aclrtGetDeviceCount(&deviceCount));
  assert(deviceCount > deviceId);
  printf("Device count: %d\n", deviceCount);

  // Set our device
  CHECK(aclrtSetDevice(deviceId));

  // Create Context
  aclrtContext context = nullptr;
  CHECK(aclrtCreateContext(&context, deviceId));

  // Create Stream
  aclrtStream stream = nullptr;
  CHECK(aclrtCreateStream(&stream));

  // Load model
  uint32_t modelId = -1;
  CHECK(aclmdlLoadFromFile(modelPath, &modelId));

  // Get Runmode
  aclrtRunMode runMode;
  CHECK(aclrtGetRunMode(&runMode));

  if (runMode == ACL_DEVICE)
    puts("ACL running on DEVICE or development board.");
  else if (runMode == ACL_HOST)
    puts("ACL running on HOST.");

  aclmdlDesc* modelDesc = aclmdlCreateDesc();

  CHECK_NULLPTR(modelDesc);

  CHECK(aclmdlGetDesc(modelDesc, modelId));

  // main code below

  size_t numInput = 0, numOutput = 0;
  numInput = aclmdlGetNumInputs(modelDesc);
  numOutput = aclmdlGetNumOutputs(modelDesc);
  printf("LeNet has %lu input, %lu output.\n", numInput, numOutput);

  aclmdlDataset* inputDataset = aclmdlCreateDataset();
  aclmdlDataset* outputDataset = aclmdlCreateDataset();

  CHECK_NULLPTR(inputDataset);
  CHECK_NULLPTR(outputDataset);

  size_t inputSize = 0, outputSize = 0;
  inputSize = aclmdlGetInputSizeByIndex(modelDesc, 0);
  outputSize = aclmdlGetOutputSizeByIndex(modelDesc, 0);

  printf("LeNet accepts input of size %lu, outputs a tensor with size %lu.\n",
         inputSize, outputSize);
  // LeNet accepts input of size 4096 (height x width x sizeof(float))
  // outputs a tensor with size 40 (num_class x sizeof(float)).

  unsigned char *imageData, *labelData;
  getDataset(&imageData, &labelData);
  int demoIdx = 51;
  colorShell(&imageData[demoIdx * 28 * 28]);
  printf("Label: %d\n", (int)labelData[demoIdx]);

  float *inputBuffer, *outputBuffer;

  CHECK(
      aclrtMalloc((void**)&inputBuffer, inputSize, ACL_MEM_MALLOC_NORMAL_ONLY));
  CHECK(aclrtMalloc((void**)&outputBuffer, outputSize,
                    ACL_MEM_MALLOC_NORMAL_ONLY));
  for (int i = 0; i < 28; ++i)
    for (int j = 0; j < 28; ++j) {
      inputBuffer[i * 28 + j] = (float)imageData[i * 28 + j] / 255.0;
    }

  aclDataBuffer* inDataBuffer = aclCreateDataBuffer(inputBuffer, inputSize);
  aclDataBuffer* outDataBuffer = aclCreateDataBuffer(outputBuffer, outputSize);

  CHECK_NULLPTR(inDataBuffer);
  CHECK_NULLPTR(outDataBuffer);

  CHECK(aclmdlAddDatasetBuffer(inputDataset, inDataBuffer));
  CHECK(aclmdlAddDatasetBuffer(outputDataset, outDataBuffer));

  CHECK(aclmdlExecute(modelId, inputDataset, outputDataset));

  // Softmax the output

  for (int i = 0; i < 10; ++i) outputBuffer[i] = exp(outputBuffer[i]);
  float sum_output = 0;
  for (int i = 0; i < 10; ++i) sum_output += outputBuffer[i];
  for (int i = 0; i < 10; ++i) outputBuffer[i] = outputBuffer[i] / sum_output;

  float maxval = outputBuffer[0];
  int maxidx = 0;
  for (int i = 1; i < 10; ++i) {
    if (float val = outputBuffer[i]; val > maxval) maxidx = i, maxval = val;
    sum_output += outputBuffer[i];
  }
  printf("Our model predicts the digit to be %d with a confidence of %f\n",
         maxidx, maxval / sum_output);

  // main code above

  CHECK(aclDestroyDataBuffer(inDataBuffer));
  CHECK(aclDestroyDataBuffer(outDataBuffer));

  CHECK(aclrtFree(inputBuffer));
  CHECK(aclrtFree(outputBuffer));

  CHECK(aclmdlDestroyDataset(inputDataset));
  CHECK(aclmdlDestroyDataset(outputDataset));
  CHECK(aclmdlDestroyDesc(modelDesc));

  CHECK(aclrtDestroyStream(stream));
  CHECK(aclrtDestroyContext(context));
  CHECK(aclrtResetDevice(deviceId));
  CHECK(aclFinalize());
}
